################################################################################
## Helper functions for policy learning with asymmetric counterfactual utilities
################################################################################
library(Matrix)




#' Look up the regret coefficients (c1, c0, c_const) for each unit
#'
#' Each unit is assigned a triple of coefficients from a small rule table.
#' Which table is used depends on whether `ug < ub`; within a table the
#' first rule that is TRUE for a unit wins, and a unit matching no rule
#' (or whose conditions are all NA) gets `NA`.
#'
#' @param delta_tau,delta_pos,delta_soc Vectors compared against 0 and 1
#'   (presumably 0/1 indicators of equal length -- verify against caller).
#' @param ug,ub Scalar utility parameters; they select the rule table and
#'   provide the coefficient values.
#' @return A list with numeric vectors `c1`, `c0` and `c_const`.
get_regret_oracle <- function(delta_tau, delta_pos, delta_soc, ug, ub) {

  # First-match-wins lookup in base R, so the function does not rely on an
  # externally attached `case_when`.
  first_match <- function(conds, values) {
    n <- max(lengths(conds))
    out <- rep(NA_real_, n)
    # Walk the rules backwards so earlier rules overwrite later ones.
    for (k in rev(seq_along(conds))) {
      hit <- rep_len(conds[[k]], n)
      out[!is.na(hit) & hit] <- values[[k]]
    }
    out
  }

  if (ug < ub) {
    rules <- list(
      delta_soc == 1,
      delta_tau == 0 & delta_pos == 0,
      delta_tau == 0 & delta_pos == 1,
      delta_tau == 1 & delta_pos == 0,
      delta_tau == 1 & delta_pos == 1
    )
    # Rows with delta_tau == 1 are the delta_tau == 0 rows shifted by
    # (-ug, +ug) in (c0, c1).
    c0_vals <- c(-ug, -ub, -ug, -ug - ub, -2 * ug)
    c1_vals <- c(ug, ug, ub, 2 * ug, ug + ub)
    const_vals <- c(0, 0, ug - ub, 0, ug - ub)
  } else {
    rules <- list(
      delta_soc == 0,
      delta_tau == 1 & delta_pos == 0,
      delta_tau == 1 & delta_pos == 1,
      delta_tau == 0 & delta_pos == 0,
      delta_tau == 0 & delta_pos == 1
    )
    # Mirror image of the branch above: rows with delta_tau == 0 are the
    # delta_tau == 1 rows shifted by (-ub, +ub) in (c0, c1). The c1 entry
    # for (tau = 0, pos = 0) is therefore ug + ub.
    c0_vals <- c(-ub, -ub, -ug, -2 * ub, -ug - ub)
    c1_vals <- c(ub, ug, ub, ug + ub, 2 * ub)
    const_vals <- c(0, 0, ug - ub, 0, ug - ub)
  }

  list(
    c1 = first_match(rules, c1_vals),
    c0 = first_match(rules, c0_vals),
    c_const = first_match(rules, const_vals)
  )
}

#' Evaluate the closed-form regret expression
#'
#' Returns `-(ug * tau + (ug - ub) * e01)`. The arguments `c0`, `c1` and
#' `c_const` are accepted but not used by this function.
#'
#' @param c0,c1,c_const Unused.
#' @param ug,ub Utility parameters.
#' @param tau,e01 Numeric inputs combined linearly with `ug` and `ub`.
#' @return The negated sum `ug * tau + (ug - ub) * e01`.
compute_true_regret <- function(c0, c1, c_const, ug, ub, tau, e01) {
  tau_part <- ug * tau
  e01_part <- (ug - ub) * e01
  -(tau_part + e01_part)
}


#' Policy-weighted average of the "01" score built from DR pseudo-outcomes
#'
#' Builds per-unit pseudo-outcomes for the control and treated arms (model
#' prediction plus an inverse-propensity-weighted residual, with weights
#' divided by their sample mean), then averages
#' `pol * (m0 + pos_class * (1 - m0 - m1))`.
#'
#' @param y Observed outcomes.
#' @param trt Treatment indicator (used as 0/1 in the weights).
#' @param pscore Propensity scores.
#' @param mu1,mu0 Outcome-model predictions under treatment / control.
#' @param pos_class Per-unit multiplier switching the score from `m0` to
#'   `1 - m1` when equal to 1 (presumably a 0/1 indicator -- confirm).
#' @param pol Policy weights multiplying each unit's score.
#' @return A single number: the sample mean of the weighted scores.
compute_dr_prob_01 <- function(y, trt, pscore, mu1, mu0, pos_class, pol) {
  ctrl_norm <- mean((1 - trt) / (1 - pscore))
  trt_norm <- mean(trt / pscore)

  ctrl_score <- mu0 + (y - mu0) * (1 - trt) / (1 - pscore) / ctrl_norm
  trt_score <- mu1 + (y - mu1) * trt / pscore / trt_norm

  unit_score <- ctrl_score + pos_class * (1 - ctrl_score - trt_score)
  mean(pol * unit_score)
}
#' Policy-weighted average of the "01" lower score from DR pseudo-outcomes
#'
#' Builds the same normalized pseudo-outcomes for the two arms as the other
#' `compute_dr_*` helpers and returns
#' `-mean(pol * (m1 - m0) * (1 - tau_pos))`.
#'
#' @param y Observed outcomes.
#' @param trt Treatment indicator (used as 0/1 in the weights).
#' @param pscore Propensity scores.
#' @param mu1,mu0 Outcome-model predictions under treatment / control.
#' @param tau_pos Per-unit multiplier; units with `tau_pos == 1` contribute
#'   zero (presumably a 0/1 indicator -- confirm).
#' @param pol Policy weights multiplying each unit's contribution.
#' @return A single number.
compute_dr_prob_01_low <- function(y, trt, pscore, mu1, mu0, tau_pos, pol) {
  ctrl_norm <- mean((1 - trt) / (1 - pscore))
  trt_norm <- mean(trt / pscore)

  ctrl_score <- mu0 + (y - mu0) * (1 - trt) / (1 - pscore) / ctrl_norm
  trt_score <- mu1 + (y - mu1) * trt / pscore / trt_norm

  contrib <- pol * (trt_score - ctrl_score) * (1 - tau_pos)
  -mean(contrib)
}

#' Complement-policy-weighted average of the "10" score from DR pseudo-outcomes
#'
#' Builds normalized pseudo-outcomes for the control and treated arms and
#' averages `(1 - pol) * (m1 + pos_class * (1 - m0 - m1))`.
#'
#' @param y Observed outcomes.
#' @param trt Treatment indicator (used as 0/1 in the weights).
#' @param pscore Propensity scores.
#' @param mu1,mu0 Outcome-model predictions under treatment / control.
#' @param pos_class Per-unit multiplier switching the score from `m1` to
#'   `1 - m0` when equal to 1 (presumably a 0/1 indicator -- confirm).
#' @param pol Policy weights; each unit is weighted by `1 - pol`.
#' @return A single number: the sample mean of the weighted scores.
compute_dr_prob_10 <- function(y, trt, pscore, mu1, mu0, pos_class, pol) {
  ctrl_norm <- mean((1 - trt) / (1 - pscore))
  trt_norm <- mean(trt / pscore)

  ctrl_score <- mu0 + (y - mu0) * (1 - trt) / (1 - pscore) / ctrl_norm
  trt_score <- mu1 + (y - mu1) * trt / pscore / trt_norm

  unit_score <- trt_score + pos_class * (1 - ctrl_score - trt_score)
  mean((1 - pol) * unit_score)
}

#' Complement-policy-weighted average of the "10" lower score
#'
#' Builds normalized pseudo-outcomes for the control and treated arms and
#' returns `mean((1 - pol) * (m1 - m0) * tau_pos)`.
#'
#' @param y Observed outcomes.
#' @param trt Treatment indicator (used as 0/1 in the weights).
#' @param pscore Propensity scores.
#' @param mu1,mu0 Outcome-model predictions under treatment / control.
#' @param tau_pos Per-unit multiplier; units with `tau_pos == 0` contribute
#'   zero (presumably a 0/1 indicator -- confirm).
#' @param pol Policy weights; each unit is weighted by `1 - pol`.
#' @return A single number.
compute_dr_prob_10_low <- function(y, trt, pscore, mu1, mu0, tau_pos, pol) {
  ctrl_norm <- mean((1 - trt) / (1 - pscore))
  trt_norm <- mean(trt / pscore)

  ctrl_score <- mu0 + (y - mu0) * (1 - trt) / (1 - pscore) / ctrl_norm
  trt_score <- mu1 + (y - mu1) * trt / pscore / trt_norm

  contrib <- (1 - pol) * (trt_score - ctrl_score) * tau_pos
  mean(contrib)
}

#' Average pseudo-outcome under a policy
#'
#' Builds normalized pseudo-outcomes for the control and treated arms and
#' returns `mean((1 - pol) * m0) + mean(pol * m1)`.
#'
#' @param y Observed outcomes.
#' @param trt Treatment indicator (used as 0/1 in the weights).
#' @param pscore Propensity scores.
#' @param mu1,mu0 Outcome-model predictions under treatment / control.
#' @param pol Policy weights: `pol` weights the treated-arm score and
#'   `1 - pol` the control-arm score.
#' @return A single number.
compute_dr_avg_outcome <- function(y, trt, pscore, mu1, mu0, pol) {
  ctrl_norm <- mean((1 - trt) / (1 - pscore))
  trt_norm <- mean(trt / pscore)

  ctrl_score <- mu0 + (y - mu0) * (1 - trt) / (1 - pscore) / ctrl_norm
  trt_score <- mu1 + (y - mu1) * trt / pscore / trt_norm

  ctrl_part <- mean((1 - pol) * ctrl_score)
  trt_part <- mean(pol * trt_score)
  ctrl_part + trt_part
}


#' Compute per-unit utilities with IPW
#'
#' Each arm's outcome is weighted by the inverse propensity of the received
#' arm, with the weights divided by their sample mean, scaled by that arm's
#' coefficient, and the constant term is added.
#'
#' @param c0,c1,c_const Coefficients on the control term, the treated term,
#'   and the additive constant.
#' @param y Observed outcomes.
#' @param trt Treatment indicator (used as 0/1 in the weights).
#' @param pscore Propensity scores.
#' @param ... Ignored.
#' @return A vector of per-unit utilities.
compute_ipw_utility <- function(c0, c1, c_const, y, trt, pscore, ...) {
  ctrl_norm <- mean((1 - trt) / (1 - pscore))
  trt_norm <- mean(trt / pscore)

  ctrl_term <- c0 * y * (1 - trt) / (1 - pscore) / ctrl_norm
  trt_term <- c1 * y * trt / pscore / trt_norm

  ctrl_term + trt_term + c_const
}

#' Compute utilities with DR (standardized pscores)
#'
#' For each arm, forms a pseudo-outcome as the model prediction plus an
#' inverse-propensity-weighted residual, where the weights are divided by
#' their sample mean. The two pseudo-outcomes are combined linearly with
#' `c0` and `c1`, and `c_const` is added.
#'
#' @param c0,c1,c_const Coefficients on the control pseudo-outcome, the
#'   treated pseudo-outcome, and the additive constant.
#' @param y Observed outcomes.
#' @param trt Treatment indicator (used as 0/1 in the weights).
#' @param pscore Propensity scores.
#' @param mu1,mu0 Outcome-model predictions under treatment / control.
#' @param ... Ignored.
#' @return A vector of per-unit utilities.
compute_dr_utility <- function(c0, c1, c_const, y, trt, pscore, mu1, mu0, ...) {
  ctrl_norm <- mean((1 - trt) / (1 - pscore))
  trt_norm <- mean(trt / pscore)

  m0hat <- mu0 + (y - mu0) * (1 - trt) / (1 - pscore) / ctrl_norm
  m1hat <- mu1 + (y - mu1) * trt / pscore / trt_norm

  c0 * m0hat + c1 * m1hat + c_const
}

#' Compute per-unit utilities with imputation
#'
#' Plugs the outcome-model predictions directly into the linear utility:
#' `c0 * mu0 + c1 * mu1 + c_const`.
#'
#' @param c0,c1,c_const Coefficients on `mu0`, on `mu1`, and the additive
#'   constant.
#' @param mu1,mu0 Outcome-model predictions under treatment / control.
#' @param ... Ignored.
#' @return A vector of per-unit utilities.
compute_impute_utility <- function(c0, c1, c_const, mu1, mu0, ...) {
  model_part <- c0 * mu0 + c1 * mu1
  model_part + c_const
}



